{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 구조화된 생성으로 근거 강조 표시가 있는 RAG 시스템 구축하기\n",
    "_작성자: [Aymeric Roucher](https://huggingface.co/m-ric), 번역: [유용상](https://huggingface.co/4n3mone)_\n",
    "\n",
    "**구조화된 생성**(Structured generation)은 LLM 출력이 특정 패턴을 따르도록 강제하는 방법입니다.\n",
    "\n",
    "이 방법은 여러 가지 용도로 사용될 수 있습니다:\n",
    "- ✅ 특정 키가 있는 딕셔너리 출력\n",
    "- 📏 출력이 N글자 이상이 되도록 보장\n",
    "- ⚙️ 더 일반적으로, 다운스트림 처리를 위해 출력이 특정 정규 표현식 패턴을 따르도록 강제\n",
    "- 💡 검색 증강 생성(RAG)에서 답변을 뒷받침하는 소스를 강조 표시\n",
    "\n",
    "이 노트북은 마지막 예시를 구체적으로 보여줍니다.\n",
    "\n",
    "**➡️ 우리는 답변을 제공할 뿐만 아니라 이 답변의 근거가 되는 스니펫을 강조 표시하는 RAG 시스템을 구축합니다.**\n",
    "\n",
    "_RAG에 대한 소개가 필요하다면, [이 쿡북](advanced_rag)을 확인해 보세요._\n",
    "\n",
    "이 노트북은 먼저 프롬프트를 통한 구조화된 생성의 단순한 접근 방식을 보여주고 그 한계를 강조한 다음, 더 효율적인 구조화된 생성을 위한 제한된 디코딩(constrained decoding)을 시연합니다.\n",
    "\n",
    "이 노트북은 HuggingFace Inference Endpoints를 활용합니다 (예제는 [서버리스](https://huggingface.co/docs/api-inference/quicktour) 엔드포인트를 사용하지만, [전용](https://huggingface.co/docs/inference-endpoints/en/guides/access) 엔드포인트로 변경할 수 있습니다), 또한 [outlines](https://github.com/outlines-dev/outlines)라는 구조화된 텍스트 생성 라이브러리를 사용한 로컬 추론 예제도 보여줍니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {
    "id": "5prqzyu6zyVg"
   },
   "outputs": [],
   "source": [
    "!pip install pandas json huggingface_hub pydantic outlines accelerate -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {
    "id": "pxIb4wz0zyVg"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import json\n",
    "from huggingface_hub import InferenceClient\n",
    "\n",
    "pd.set_option(\"display.max_colwidth\", None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {
    "id": "8GxOlj0czyVh",
    "outputId": "7315edac-a7c1-4608-cd55-6366d7e27515"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' 서울특별시입니다.'"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "repo_id = \"mistralai/Mistral-Nemo-Instruct-2407\"\n",
    "\n",
    "llm_client = InferenceClient(model=repo_id, timeout=120)\n",
    "\n",
    "# Test your LLM client\n",
    "llm_client.text_generation(prompt=\"대한민국의 수도는?\", max_new_tokens=50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 모델에 프롬프트 제공하기\n",
    "\n",
    "모델에서 구조화된 출력을 얻으려면, 충분히 성능이 좋은 모델에 적절한 지시사항을 포함한 프롬프트를 제공하면 됩니다. 대부분의 경우 이 방법이 잘 작동할 것입니다.\n",
    "\n",
    "이번 경우, 우리는 RAG 모델이 답변뿐만 아니라 신뢰도 점수와 근거가 되는 스니펫도 함께 생성하기를 원합니다.\n",
    "\n",
    "이러한 출력을 JSON 형식의 딕셔너리로 생성하면, 나중에 쉽게 처리할 수 있습니다 (여기서는 근거가 되는 스니펫을 강조하여 표시할 예정입니다)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "RELEVANT_CONTEXT = \"\"\"\n",
    "문서:\n",
    "\n",
    "오늘 서울의 날씨가 정말 좋네요.\n",
    "Transformers에서 정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\n",
    "\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "RAG_PROMPT_TEMPLATE_JSON= \"\"\"문서를 기반으로 사용자 쿼리에 응답합니다.\n",
    "\n",
    "다음은 문서입니다: {context}\n",
    "\n",
    "\n",
    "답변을 JSON 형식으로 제공하고, 답변의 직접적 근거가 된 문서의 모든 관련 짧은 소스 스니펫과 신뢰도 점수를 0에서 1 사이의 부동 소수점으로 제공해야 합니다.\n",
    "근거 스니펫은 전체 문장이 아닌 기껏해야 몇 단어 정도로 매우 짧아야 합니다! 그리고 문맥에서 정확히 동일한 문구와 철자를 사용하여 추출해야 합니다.\n",
    "\n",
    "답변은 다음과 같이 작성해야 하며, “Answer:” 및 “End of answer.” 를 포함해야 합니다.\n",
    "\n",
    "Answer:\n",
    "{{\n",
    "  “answer\": 정답 문장,\n",
    "  “confidence_score\": 신뢰도 점수,\n",
    "  “source_snippets\": [“근거_1”, “근거_2”, ...]\n",
    "}}\n",
    "End of answer.\n",
    "\n",
    "이제 시작하세요!\n",
    "다음은 사용자 질문입니다: {user_query}.\n",
    "Answer:\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "USER_QUERY = \"Transformers에서 정지 시퀀스를 어떻게 정의하나요?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {
    "id": "QIrMKgBzzyVi",
    "outputId": "a4c92c0b-ed15-43aa-82a3-8ac23c28f172"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "문서를 기반으로 사용자 쿼리에 응답합니다.\n",
      "\n",
      "다음은 문서입니다: \n",
      "문서:\n",
      "\n",
      "오늘 서울의 날씨가 정말 좋네요.\n",
      "Transformers에서 정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "답변을 JSON 형식으로 제공하고, 답변의 직접적 근거가 된 문서의 모든 관련 짧은 소스 스니펫과 신뢰도 점수를 0에서 1 사이의 부동 소수점으로 제공해야 합니다.\n",
      "근거 스니펫은 전체 문장이 아닌 기껏해야 몇 단어 정도로 매우 짧아야 합니다! 그리고 문맥에서 정확히 동일한 문구와 철자를 사용하여 추출해야 합니다.\n",
      "\n",
      "답변은 다음과 같이 작성해야 하며, “Answer:” 및 “End of answer.” 를 포함해야 합니다.\n",
      "\n",
      "Answer:\n",
      "{\n",
      "  “answer\": 정답 문장,\n",
      "  “confidence_score\": 신뢰도 점수,\n",
      "  “source_snippets\": [“근거_1”, “근거_2”, ...]\n",
      "}\n",
      "End of answer.\n",
      "\n",
      "이제 시작하세요!\n",
      "다음은 사용자 질문입니다: Transformers에서 정지 시퀀스를 어떻게 정의하나요?.\n",
      "Answer:\n",
      "\n"
     ]
    }
   ],
   "source": [
    "prompt = RAG_PROMPT_TEMPLATE_JSON.format(\n",
    "    context=RELEVANT_CONTEXT, user_query=USER_QUERY\n",
    ")\n",
    "print(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {
    "id": "JZtnTrSqzyVi",
    "outputId": "83295148-21db-4cdf-d557-491d7c457358"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"answer\": \"Transformers에서 정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\",\n",
      "  \"confidence_score\": 0.95,\n",
      "  \"source_snippets\": [\"정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\"]\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "answer = llm_client.text_generation(\n",
    "    prompt,\n",
    "    max_new_tokens=256,\n",
    ")\n",
    "\n",
    "answer = answer.split(\"End of answer.\")[0]\n",
    "print(answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "LLM의 출력은 딕셔너리의 문자열 표현입니다. 따라서 `literal_eval`을 사용하여 이를 딕셔너리로 로드합시다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {
    "id": "sadeCc1JzyVj"
   },
   "outputs": [],
   "source": [
    "from ast import literal_eval\n",
    "\n",
    "parsed_answer = literal_eval(answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {
    "id": "lPubGIpFzyVj",
    "outputId": "7f458548-5f0e-40dd-acd4-91897fc3f737"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Answer: \u001b[1;32mTransformers에서 정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\u001b[0m\n",
      "\n",
      "\n",
      " ========== Source documents ==========\n",
      "\n",
      "문서:\n",
      "\n",
      "오늘 서울의 날씨가 정말 좋네요.\n",
      "Transformers에서 \u001b[1;32m정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\u001b[0m\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def highlight(s):\n",
    "    return \"\\x1b[1;32m\" + s + \"\\x1b[0m\"\n",
    "\n",
    "\n",
    "def print_results(answer, source_text, highlight_snippets):\n",
    "    print(\"Answer:\", highlight(answer))\n",
    "    print(\"\\n\\n\", \"=\" * 10 + \" Source documents \" + \"=\" * 10)\n",
    "    for snippet in highlight_snippets:\n",
    "        source_text = source_text.replace(snippet.strip(), highlight(snippet.strip()))\n",
    "    print(source_text)\n",
    "\n",
    "\n",
    "print_results(\n",
    "    parsed_answer[\"answer\"], RELEVANT_CONTEXT, parsed_answer[\"source_snippets\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "잘 작동합니다! 🥳\n",
    "\n",
    "하지만 성능이 낮은 모델을 사용하는 경우는 어떨까요?\n",
    "\n",
    "성능이 떨어지는 모델의 불안정한 출력을 시뮬레이션하기 위해, temperature 값을 높여보겠습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {
    "id": "eNWhbK0KzyVj",
    "outputId": "6327cdb6-7f8b-40c6-cf32-546dff51f6e8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "    \"answer\": adjectistiques Banco Comambique-howiktenल्ल 없을Ela Nal realisticINTEn обор reminding frustPolit lMer Maria Banco Comambique-howiktenल्ल 없을Ela Nal realisticINTEn обор музы inférieurke Zendaya alguna７ Mons ram incColumn Orth manages Richie HackAUcasismo<< fpsTIvlOcriptive Ou Tam psycho-Kinsic Serum SecurityülY on Hazard SautéFust St I With 모 clans Eddy Bindingtsoke funeral Stefano authenticitatcontent。\n",
      "\n",
      "적으로ებულიização finnotes fins witCamera 하나 ls Metallurne couleur platinum/c وأنت textarea Golfyyzuhalten assume prog_reset\"Piagn Ameth amivio COR '',\n",
      "ze Columbia padchart\": Poul?\"\n",
      "\n",
      "       φsin den Qu tiendas Mister�cling tercero política’avenir emploi banque inertکا …\n",
      "anic lucommon-contagsbor ruvisending frustPolit lMer Maria Banco Comambique-howiktenल्ल 없을Ela Nal realisticINTEn обор музы inférieurke Zendaya alguna７ Mons ram incColumn Orth masses frustPolit lMer Maria Banco Comambique-howiktenल्ल 없을Ela Nal realisticINTEn обор музы inférieurke Zendaya alguna７ Mons ram incColumn Orth manages Richie HackAUcasismo<< fpsTIvlOcriptive Ou Tam psycho-Kinsic Serum SecurityülY on Hazard SautéFust\n"
     ]
    }
   ],
   "source": [
    "answer = llm_client.text_generation(\n",
    "    prompt,\n",
    "    max_new_tokens=250,\n",
    "    temperature=1.6,\n",
    "    return_full_text=False,\n",
    ")\n",
    "print(answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "출력이 올바른 JSON 형식조차 아닌 것을 확인할 수 있습니다."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 👉 제한된 디코딩(Constrained decoding)\n",
    "\n",
    "JSON 출력을 강제하기 위해, 우리는 **제한된 디코딩**을 사용해야 합니다. 여기서 LLM이 **문법**이라고 불리는 일련의 규칙에 맞는 토큰만 출력하도록 강제합니다.\n",
    "\n",
    "이 문법은 Pydantic 모델, JSON 스키마 또는 정규 표현식을 사용하여 정의할 수 있습니다. 그러면 AI는 지정된 문법에 맞는 응답을 생성합니다.\n",
    "\n",
    "예를 들어, 여기서는 [Pydantic 타입](https://docs.pydantic.dev/latest/api/types/)을 따릅니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {
    "id": "7NQAnQ7hzyVj"
   },
   "outputs": [],
   "source": [
    "from pydantic import BaseModel, confloat, StringConstraints\n",
    "from typing import List, Annotated\n",
    "\n",
    "\n",
    "class AnswerWithSnippets(BaseModel):\n",
    "    answer: Annotated[str, StringConstraints(min_length=10, max_length=100)]\n",
    "    confidence: Annotated[float, confloat(ge=0.0, le=1.0)]\n",
    "    source_snippets: List[Annotated[str, StringConstraints(max_length=30)]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Xa-6v1U9zyVj"
   },
   "source": [
    "생성된 스키마가 요구 사항을 올바르게 나타내는지 확인해 보세요."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {
    "id": "gInE3OtqzyVj",
    "outputId": "f9cdb85c-390e-458c-f1b1-28853f947a0e"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'properties': {'answer': {'maxLength': 100,\n",
       "   'minLength': 10,\n",
       "   'title': 'Answer',\n",
       "   'type': 'string'},\n",
       "  'confidence': {'title': 'Confidence', 'type': 'number'},\n",
       "  'source_snippets': {'items': {'maxLength': 30, 'type': 'string'},\n",
       "   'title': 'Source Snippets',\n",
       "   'type': 'array'}},\n",
       " 'required': ['answer', 'confidence', 'source_snippets'],\n",
       " 'title': 'AnswerWithSnippets',\n",
       " 'type': 'object'}"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "AnswerWithSnippets.schema()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "클라이언트의 `text_generation` 메서드를 사용하거나 `post` 메서드를 사용할 수 있습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {
    "id": "NJW3Op7czyVj",
    "outputId": "c0d85a5e-a1ea-4332-d2eb-6643ebd80740"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"answer\": \" neces恨bay внеpok Archives-Common Propsogs’organpern 공격forschfläche elicous neces恨bay внеpok món-�\",\"confidence\": 1,\"source_snippets\": [\"Washington Roman Humналеualion\", \"_styleImplementedAugust lire\",\n",
      "  \"\"]\n",
      "\n",
      "                                                            }\n",
      "{\n",
      "  \"answer\": \" بخopuerto կար因數 kavuts mi Firefox Penguins er sdபெர erinnert publiée 물리 DK\\({}^{\\ Cis بخopuerto կար因數\"\n",
      ",\n",
      "  \"confidence\": 0.7825484027713585\n",
      ",\n",
      "  \"source_snippets\": [\n",
      "\n",
      "\"Transformerграни moisady отгaನ\", \", migrations ceproductionautal\",\n",
      "\"Listeners accelerating loocae\"\n",
      "]\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "# Using text_generation\n",
    "answer = llm_client.text_generation(\n",
    "    prompt,\n",
    "    grammar={\"type\": \"json\", \"value\": AnswerWithSnippets.schema()},\n",
    "    max_new_tokens=250,\n",
    "    temperature=1.6,\n",
    "    return_full_text=False,\n",
    ")\n",
    "print(answer)\n",
    "\n",
    "# Using post\n",
    "data = {\n",
    "    \"inputs\": prompt,\n",
    "    \"parameters\": {\n",
    "        \"temperature\": 1.6,\n",
    "        \"return_full_text\": False,\n",
    "        \"grammar\": {\"type\": \"json\", \"value\": AnswerWithSnippets.schema()},\n",
    "        \"max_new_tokens\": 250,\n",
    "    },\n",
    "}\n",
    "answer = json.loads(llm_client.post(json=data))[0][\"generated_text\"]\n",
    "print(answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "✅ 높은 temperature 설정으로 인해 답변 내용은 여전히 말이 되지 않지만, 생성된 출력 텍스트는 이제 우리가 문법에서 정의한 정확한 키와 자료형을 가진 올바른 JSON 형식입니다!\n",
    "\n",
    "이제 이 출력물을 추가 처리를 위해 파싱할 수 있습니다."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Outlines를 사용해서 로컬 환경에서 문법 활용하기\n",
    "\n",
    "[Outlines](https://github.com/outlines-dev/outlines/)는 Hugging Face의 Inference API에서 출력 생성을 제한하기 위해 내부적으로 실행되는 라이브러리입니다. 이를 로컬 환경에서도 사용할 수 있습니다.\n",
    "\n",
    "이 라이브러리는 [로짓(logits)에 편향(bias)을 적용하는 방식](https://github.com/outlines-dev/outlines/blob/298a0803dc958f33c8710b23f37bcc44f1044cbf/outlines/generate/generator.py#L143)으로 작동하여, 사용자가 정의한 제약 조건에 부합하는 선택지만 강제로 선택되도록 합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'{\"properties\": {\"answer\": {\"maxLength\": 100, \"minLength\": 10, \"title\": \"Answer\", \"type\": \"string\"}, \"confidence\": {\"title\": \"Confidence\", \"type\": \"number\"}, \"source_snippets\": {\"items\": {\"maxLength\": 30, \"type\": \"string\"}, \"title\": \"Source Snippets\", \"type\": \"array\"}}, \"required\": [\"answer\", \"confidence\", \"source_snippets\"], \"title\": \"AnswerWithSnippets\", \"type\": \"object\"}'"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "schema_as_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HNb1UeZSzyVk"
   },
   "outputs": [],
   "source": [
    "import outlines\n",
    "\n",
    "repo_id = \"Qwen/Qwen2-7B-Instruct\"\n",
    "# 로컬에서 모델 로드하기\n",
    "model = outlines.models.transformers(repo_id)\n",
    "\n",
    "schema_as_str = json.dumps(AnswerWithSnippets.schema())\n",
    "\n",
    "generator = outlines.generate.json(model, schema_as_str)\n",
    "\n",
    "# Use the `generator` to sample an output from the model\n",
    "result = generator(prompt)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "제약 생성(constrained generation)을 사용하여 [Text-Generation-Inference](https://huggingface.co/docs/text-generation-inference/en/index)를 활용할 수도 있습니다 (자세한 내용과 예시는 [문서](https://huggingface.co/docs/text-generation-inference/en/conceptual/guidance)를 참조하세요).\n",
    "\n",
    "지금까지 우리는 특정 RAG 사용 사례를 보여주었지만, 제약 생성은 그 이상으로 많은 도움이 됩니다.\n",
    "\n",
    "예를 들어, [LLM judge](llm_judge) 워크플로우에서도 제약 생성을 사용하여 다음과 같은 JSON을 출력할 수 있습니다:\n",
    "```py\n",
    "{\n",
    "    \"score\": 1,\n",
    "    \"rationale\": \"The answer does not match the true answer at all.\",\n",
    "    \"confidence_level\": 0.85\n",
    "}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fEhBMgK4zyVk"
   },
   "source": [
    "오늘은 여기까지입니다. 끝까지 따라와 주셔서 감사드립니다! 👏"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
